import os
import pandas as pd
from transformers import AutoTokenizer
from torch.utils.data import DataLoader, Dataset
from loguru import logger

# The 57 MMLU subject names. Each entry matches the file stem of the
# corresponding CSVs read below: "<task>_dev.csv" under dev/ and
# "<task>_test.csv" under test/ of the MMLU data directory.
TASKS = [
    'abstract_algebra',
    'anatomy',
    'astronomy',
    'business_ethics',
    'clinical_knowledge',
    'college_biology',
    'college_chemistry',
    'college_computer_science',
    'college_mathematics',
    'college_medicine',
    'college_physics',
    'computer_security',
    'conceptual_physics',
    'econometrics',
    'electrical_engineering',
    'elementary_mathematics',
    'formal_logic',
    'global_facts',
    'high_school_biology',
    'high_school_chemistry',
    'high_school_computer_science',
    'high_school_european_history',
    'high_school_geography',
    'high_school_government_and_politics',
    'high_school_macroeconomics',
    'high_school_mathematics',
    'high_school_microeconomics',
    'high_school_physics',
    'high_school_psychology',
    'high_school_statistics',
    'high_school_us_history',
    'high_school_world_history',
    'human_aging',
    'human_sexuality',
    'international_law',
    'jurisprudence',
    'logical_fallacies',
    'machine_learning',
    'management',
    'marketing',
    'medical_genetics',
    'miscellaneous',
    'moral_disputes',
    'moral_scenarios',
    'nutrition',
    'philosophy',
    'prehistory',
    'professional_accounting',
    'professional_law',
    'professional_medicine',
    'professional_psychology',
    'public_relations',
    'security_studies',
    'sociology',
    'us_foreign_policy',
    'virology',
    'world_religions'
]
# Answer-option letters; index j maps to CSV column j + 1 in format_example.
choices = ["A", "B", "C", "D"]


def format_subject(subject):
    """Turn an underscore-separated task name into a display string.

    Every component is prefixed by a single space, so the result carries a
    leading space (e.g. "abstract_algebra" -> " abstract algebra"). Callers
    splice this directly into prompt text, so that spacing is preserved
    deliberately.
    """
    # Single join instead of quadratic += concatenation in a loop.
    return "".join(" " + part for part in subject.split("_"))


def gen_prompt(train_df, subject, k=-1):
    """Build the few-shot prompt header for *subject*.

    Starts with an introductory sentence, then appends the first *k*
    answered examples from *train_df* (all rows when k == -1).
    """
    header = (
        "The following are multiple choice questions (with answers) about {}.\n\n"
        .format(format_subject(subject))
    )
    num_examples = train_df.shape[0] if k == -1 else k
    examples = [format_example(train_df, row) for row in range(num_examples)]
    return header + "".join(examples)


def format_example(df, idx, include_answer=True):
    """Render row *idx* of *df* as one multiple-choice question block.

    Layout: question text, one "X. option" line per answer column, then an
    "Answer:" line. When *include_answer* is True the gold answer from the
    last column is appended, followed by a blank separator line.
    """
    num_options = df.shape[1] - 2  # first col = question, last col = answer
    parts = [df.iloc[idx, 0]]
    for opt in range(num_options):
        parts.append("\n{}. {}".format(choices[opt], df.iloc[idx, opt + 1]))
    parts.append("\nAnswer:")
    if include_answer:
        parts.append(" {}\n\n".format(df.iloc[idx, num_options + 1]))
    return "".join(parts)


class mmlu(Dataset):
    """Torch Dataset of tokenized MMLU evaluation prompts.

    For every task in TASKS, `config.few_shots` answered dev examples are
    prepended to each test question. Prompts whose token count (plus one
    token reserved for the answer) exceeds `config.seq_length` are shortened
    by dropping few-shot examples one at a time. Each record stores the
    tokenized prompt and the tokenized gold answer letter.
    """

    def __init__(self, config):
        # Tokenizer comes from the model weight directory so that prompt
        # length accounting matches the model under evaluation.
        self.tokenizer = AutoTokenizer.from_pretrained(
            os.path.join(config.data_dir, config.weight_dir))
        self.records = []
        self.length = 0

        # Loop-invariant: the MMLU data root is the same for every task.
        mmlu_datapath = os.path.join(config.data_dir, config.mmlu_dir)

        for task in TASKS:
            # Log the configured shot count rather than a hard-coded "5-shot".
            logger.debug("Loading {}-shot {}".format(config.few_shots, task))
            dev_df = pd.read_csv(os.path.join(mmlu_datapath, "dev",
                                              task + "_dev.csv"),
                                 header=None)[:config.few_shots]
            test_df = pd.read_csv(os.path.join(mmlu_datapath, "test",
                                               task + "_test.csv"),
                                  header=None)

            for i in range(test_df.shape[0]):
                prompt_end = format_example(test_df, i, include_answer=False)
                train_prompt = gen_prompt(dev_df, task, config.few_shots)
                prompt = train_prompt + prompt_end
                # Shrink over-long prompts: split on the blank-line
                # separators (index 0 is the intro sentence) and pop the
                # oldest remaining few-shot example until the prompt plus
                # one answer token fits in config.seq_length.
                while len(self.tokenizer.tokenize(
                        prompt)) + 1 > config.seq_length:
                    prompt_split = prompt.split("\n\n")
                    prompt_split.pop(1)
                    prompt = "\n\n".join(prompt_split)
                # Gold answer letter lives in the last CSV column.
                label = test_df.iloc[i, test_df.shape[1] - 1]
                token_prompt = self.tokenizer(prompt, return_tensors="pt")
                token_label = self.tokenizer([label], return_tensors="pt")
                self.records.append({
                    "prompt": token_prompt,
                    "answer": token_label.input_ids
                })
                self.length += 1

    def __len__(self):
        """Number of prepared (prompt, answer) records."""
        return self.length

    def __getitem__(self, idx):
        """Return the tokenized record at position *idx*."""
        return self.records[idx]


def build_dataloader(config, tokenizer=None):
    """Build a sequential, batch-size-1 DataLoader over the MMLU dataset.

    Args:
        config: settings object; must provide the fields mmlu() reads plus
            `num_workers_dataloader`.
        tokenizer: unused — the dataset constructs its own tokenizer from
            the config. Kept (now optional) for backward compatibility with
            existing call sites.

    Returns:
        DataLoader yielding one tokenized prompt/answer record per batch,
        in task order, with no shuffling and no dropped records.
    """
    dataset = mmlu(config)
    return DataLoader(dataset,
                      batch_size=1,
                      shuffle=False,
                      drop_last=False,
                      num_workers=config.num_workers_dataloader,
                      pin_memory=True)
